{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import ch9util\n",
    "from sklearn.ensemble import BaggingClassifier\n",
    "# GridSearchCV moved to sklearn.model_selection in scikit-learn 0.18 and\n",
    "# the old sklearn.grid_search module was removed in 0.20; try the current\n",
    "# location first and fall back only for legacy installs.\n",
    "try:\n",
    "    from sklearn.model_selection import GridSearchCV\n",
    "except ImportError:\n",
    "    from sklearn.grid_search import GridSearchCV\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "import numpy as np\n",
    "import dautil as dl\n",
    "from IPython.display import HTML\n",
    "import warnings\n",
    "\n",
    "# Silence FutureWarning messages so they do not clutter the cell output.\n",
    "warnings.filterwarnings(action='ignore', category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Train/test split supplied by the chapter helper module.\n",
    "X_train, X_test, y_train, y_test = ch9util.rain_split()\n",
    "\n",
    "# Pass the tree positionally: the keyword for the wrapped estimator was\n",
    "# renamed from `base_estimator` to `estimator` in scikit-learn 1.2 (the old\n",
    "# name was removed in 1.4), but it is the first positional argument in both.\n",
    "tree = DecisionTreeClassifier(min_samples_leaf=3, max_depth=4)\n",
    "clf = BaggingClassifier(tree, random_state=43)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The constructor parameter holding the wrapped tree is `base_estimator`\n",
    "# in older scikit-learn releases and `estimator` in newer ones. Pick the\n",
    "# one that actually contains an estimator so the nested grid key resolves;\n",
    "# default to the historical name if neither does.\n",
    "shallow_params = clf.get_params(deep=False)\n",
    "tree_key = next((name for name in ('estimator', 'base_estimator')\n",
    "                 if hasattr(shallow_params.get(name), 'get_params')),\n",
    "                'base_estimator')\n",
    "\n",
    "# Hyperparameter grid: ensemble size, feature bootstrapping, and the split\n",
    "# criterion of the wrapped decision tree.\n",
    "params = {\n",
    "    'n_estimators': [320, 640],\n",
    "    'bootstrap_features': [True, False],\n",
    "    tree_key + '__criterion': ['gini', 'entropy']\n",
    "}\n",
    "\n",
    "# 5-fold cross-validated search using all available cores.\n",
    "gscv = GridSearchCV(estimator=clf, param_grid=params,\n",
    "                    cv=5, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Run the grid search on the training split, then predict on the held-out\n",
    "# split; fit() returns the search object, so the calls can be chained.\n",
    "preds = gscv.fit(X_train, y_train).predict(X_test)\n",
    "\n",
    "# Hand the predictions to the chapter helper under the name 'bagging'\n",
    "# (presumably persisted as a .npy file -- confirm in ch9util).\n",
    "ch9util.npy_save('bagging', preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Render matplotlib figures inline in the notebook output.\n",
    "%matplotlib inline\n",
    "\n",
    "# Notebook context named 'bagging'; the plotting cell below reuses it.\n",
    "context = dl.nb.Context('bagging')\n",
    "\n",
    "# Build the widget for this context and show it as the cell's output.\n",
    "rc_widget = dl.nb.RcWidget(context)\n",
    "rc_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 2x2 grid of subplots tied to the notebook context.\n",
    "sp = dl.plotting.Subplotter(2, 2, context)\n",
    "\n",
    "# First panel: report for the test-set predictions; the helper also\n",
    "# returns an HTML fragment that is displayed at the end of the cell.\n",
    "html = ch9util.report_rain(preds, y_test, gscv.best_params_, sp.ax)\n",
    "\n",
    "best_clf = gscv.best_estimator_\n",
    "\n",
    "# Powers of two to sweep: 16..1024 trees and 16..8192 samples.\n",
    "ntrees = 2 ** np.arange(4, 11)\n",
    "nsamples = 2 ** np.arange(4, 14)\n",
    "\n",
    "# One validation plot per swept parameter, each on the next free axes.\n",
    "for param_name, param_range in (('n_estimators', ntrees),\n",
    "                                ('max_samples', nsamples)):\n",
    "    ch9util.plot_validation(sp.next_ax(), best_clf,\n",
    "                            X_train, y_train, param_name, param_range)\n",
    "\n",
    "# Last panel: learning curve for the best estimator.\n",
    "ch9util.plot_learn_curve(sp.next_ax(), best_clf, X_train, y_train)\n",
    "\n",
    "# Show the report HTML followed by whatever the subplotter emits on exit.\n",
    "HTML(html + sp.exit())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
